/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.runners.core;



import com.bff.gaia.unified.model.pipeline.v1.RunnerApi;

import com.bff.gaia.unified.runners.core.construction.PTransformReplacements;

import com.bff.gaia.unified.runners.core.construction.ReplacementOutputs;

import com.bff.gaia.unified.runners.core.construction.SplittableParDo;

import com.bff.gaia.unified.runners.core.construction.PTransformTranslation;

import com.bff.gaia.unified.sdk.coders.ByteArrayCoder;

import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.KvCoder;

import com.bff.gaia.unified.sdk.options.PipelineOptions;

import com.bff.gaia.unified.sdk.runners.AppliedPTransform;

import com.bff.gaia.unified.sdk.runners.PTransformOverrideFactory;

import com.bff.gaia.unified.sdk.state.TimeDomain;

import com.bff.gaia.unified.sdk.state.ValueState;

import com.bff.gaia.unified.sdk.state.WatermarkHoldState;

import com.bff.gaia.unified.sdk.transforms.DoFn;

import com.bff.gaia.unified.sdk.transforms.GroupByKey;

import com.bff.gaia.unified.sdk.transforms.PTransform;

import com.bff.gaia.unified.sdk.transforms.reflect.DoFnInvoker;

import com.bff.gaia.unified.sdk.transforms.reflect.DoFnInvokers;

import com.bff.gaia.unified.sdk.transforms.splittabledofn.RestrictionTracker;

import com.bff.gaia.unified.sdk.transforms.windowing.BoundedWindow;

import com.bff.gaia.unified.sdk.transforms.windowing.GlobalWindow;

import com.bff.gaia.unified.sdk.transforms.windowing.TimestampCombiner;

import com.bff.gaia.unified.sdk.util.WindowedValue;

import com.bff.gaia.unified.vendor.guava.com.google.common.annotations.VisibleForTesting;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.Iterables;

import com.bff.gaia.unified.sdk.values.KV;

import com.bff.gaia.unified.sdk.values.PCollection;

import com.bff.gaia.unified.sdk.values.PCollectionTuple;

import com.bff.gaia.unified.sdk.values.PCollectionView;

import com.bff.gaia.unified.sdk.values.PValue;

import com.bff.gaia.unified.sdk.values.TupleTag;

import com.bff.gaia.unified.sdk.values.TupleTagList;

import com.bff.gaia.unified.sdk.values.WindowingStrategy;

import org.joda.time.Instant;



import javax.annotation.Nullable;

import java.util.List;

import java.util.Map;



/**

 * Utility transforms and overrides for running (typically unbounded) splittable DoFn's with

 * checkpointing by implementing {@link SplittableParDo.ProcessKeyedElements} using {@link KeyedWorkItem} and

 * runner-specific {@link StateInternals} and {@link TimerInternals}.

 *

 * <p>A runner that uses {@link OverrideFactory} will need to also provide runner-specific overrides

 * for {@link GBKIntoKeyedWorkItems} and {@link ProcessElements}.

 */

public class SplittableParDoViaKeyedWorkItems {

  /**

   * Runner-specific primitive {@link GroupByKey GroupByKey-like} {@link PTransform} that produces

   * {@link KeyedWorkItem KeyedWorkItems} so that downstream transforms can access state and timers.

   *

   * <p>Unlike a real {@link GroupByKey}, ignores the input's windowing and triggering strategy and

   * emits output immediately.

   */

  public static class GBKIntoKeyedWorkItems<KeyT, InputT>
      extends PTransformTranslation.RawPTransform<
          PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>> {

    /**
     * Creates the primitive output collection of {@link KeyedWorkItem KeyedWorkItems}.
     *
     * <p>The output uses the global default windowing strategy and the boundedness of {@code
     * input}. Its coder is a {@link KeyedWorkItemCoder} assembled from the key and value coders of
     * the input's {@link KvCoder} plus the window coder of the input's windowing strategy.
     *
     * @param input the keyed input; its coder is cast to {@link KvCoder}
     * @return the primitive output collection for this transform
     */
    @Override
    public PCollection<KeyedWorkItem<KeyT, InputT>> expand(PCollection<KV<KeyT, InputT>> input) {
      KvCoder<KeyT, InputT> inputCoder = (KvCoder<KeyT, InputT>) input.getCoder();
      Coder<KeyT> keyCoder = inputCoder.getKeyCoder();
      Coder<InputT> valueCoder = inputCoder.getValueCoder();

      // The work items carry the windows of the *input*, even though the output collection itself
      // is declared with the global default windowing strategy below.
      Coder<KeyedWorkItem<KeyT, InputT>> workItemCoder =
          KeyedWorkItemCoder.of(
              keyCoder, valueCoder, input.getWindowingStrategy().getWindowFn().windowCoder());

      return PCollection.createPrimitiveOutputInternal(
          input.getPipeline(), WindowingStrategy.globalDefault(), input.isBounded(), workItemCoder);
    }

    /** Returns the URN identifying this transform, {@code SPLITTABLE_GBKIKWI_URN}. */
    @Override
    public String getUrn() {
      return SplittableParDo.SPLITTABLE_GBKIKWI_URN;
    }

    /**
     * Always fails: this transform has no proto representation.
     *
     * @throws UnsupportedOperationException on every call
     */
    @Override
    public RunnerApi.FunctionSpec getSpec() {
      String transformName = getClass().getSimpleName();
      throw new UnsupportedOperationException(
          String.format("%s should never be serialized to proto", transformName));
    }
  }



  /** Overrides a {@link SplittableParDo.ProcessKeyedElements} into {@link SplittableProcessViaKeyedWorkItems}. */

  public static class OverrideFactory<InputT, OutputT, RestrictionT>
      implements PTransformOverrideFactory<
          PCollection<KV<byte[], KV<InputT, RestrictionT>>>,
          PCollectionTuple,
          SplittableParDo.ProcessKeyedElements<InputT, OutputT, RestrictionT>> {

    /**
     * Builds the replacement for an applied {@link SplittableParDo.ProcessKeyedElements}.
     *
     * @param transform the application being replaced
     * @return a replacement that feeds the application's single main input into a new {@link
     *     SplittableProcessViaKeyedWorkItems} wrapping the original transform
     */
    @Override
    public PTransformReplacement<
            PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple>
        getReplacementTransform(
            AppliedPTransform<
                    PCollection<KV<byte[], KV<InputT, RestrictionT>>>,
                    PCollectionTuple,
                    SplittableParDo.ProcessKeyedElements<InputT, OutputT, RestrictionT>>
                transform) {
      PCollection<KV<byte[], KV<InputT, RestrictionT>>> mainInput =
          PTransformReplacements.getSingletonMainInput(transform);
      SplittableProcessViaKeyedWorkItems<InputT, OutputT, RestrictionT> replacement =
          new SplittableProcessViaKeyedWorkItems<>(transform.getTransform());
      return PTransformReplacement.of(mainInput, replacement);
    }

    /**
     * Maps the original outputs onto the replacement's outputs by tag.
     *
     * @param outputs the tagged outputs of the original transform
     * @param newOutput the output tuple produced by the replacement
     * @return the result of {@link ReplacementOutputs#tagged}
     */
    @Override
    public Map<PValue, ReplacementOutput> mapOutputs(
        Map<TupleTag<?>, PValue> outputs, PCollectionTuple newOutput) {
      Map<PValue, ReplacementOutput> replacements = ReplacementOutputs.tagged(outputs, newOutput);
      return replacements;
    }
  }



  /**

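   * Composite {@link PTransform} that runs a splittable {@link DoFn} by applying {@link
   * GBKIntoKeyedWorkItems} followed by {@link ProcessElements}, the two runner-specific primitives
   * that ultimately invoke the {@link DoFn.ProcessElement} method.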

   */

  public static class SplittableProcessViaKeyedWorkItems<InputT, OutputT, RestrictionT>
      extends PTransform<PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
    /** The transform being replaced; passed on to {@link ProcessElements}. */
    private final SplittableParDo.ProcessKeyedElements<InputT, OutputT, RestrictionT> original;

    /**
     * @param original the {@link SplittableParDo.ProcessKeyedElements} this transform stands in for
     */
    public SplittableProcessViaKeyedWorkItems(
        SplittableParDo.ProcessKeyedElements<InputT, OutputT, RestrictionT> original) {
      this.original = original;
    }

    /**
     * Expands into {@link GBKIntoKeyedWorkItems} followed by {@link ProcessElements}.
     *
     * <p>The intermediate collection's coder is set explicitly to a {@link KeyedWorkItemCoder}
     * built from {@link ByteArrayCoder}, the value coder of the input's {@link KvCoder}, and the
     * window coder of the input's windowing strategy.
     *
     * @param input keyed (element, restriction) pairs; its coder is cast to {@link KvCoder}
     * @return the outputs produced by {@link ProcessElements}
     */
    @Override
    public PCollectionTuple expand(PCollection<KV<byte[], KV<InputT, RestrictionT>>> input) {
      // Step 1: group into keyed work items.
      PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> grouped =
          input.apply(new GBKIntoKeyedWorkItems<>());

      // Step 2: pin the coder of the intermediate collection.
      KvCoder<byte[], KV<InputT, RestrictionT>> inputCoder =
          (KvCoder<byte[], KV<InputT, RestrictionT>>) input.getCoder();
      Coder<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> workItemCoder =
          KeyedWorkItemCoder.of(
              ByteArrayCoder.of(),
              inputCoder.getValueCoder(),
              input.getWindowingStrategy().getWindowFn().windowCoder());
      PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> coded =
          grouped.setCoder(workItemCoder);

      // Step 3: hand the work items to the primitive that drives the splittable DoFn.
      return coded.apply(new ProcessElements<>(original));
    }
  }



  /** A primitive transform wrapping around {@link ProcessFn}. */

  public static class ProcessElements<InputT, OutputT, RestrictionT, PositionT>
      extends PTransform<
          PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
    /** Source of the user fn, coders, output tags, side inputs and windowing strategy. */
    private final SplittableParDo.ProcessKeyedElements<InputT, OutputT, RestrictionT> original;

    /**
     * @param original the {@link SplittableParDo.ProcessKeyedElements} whose configuration this
     *     transform exposes
     */
    public ProcessElements(
        SplittableParDo.ProcessKeyedElements<InputT, OutputT, RestrictionT> original) {
      this.original = original;
    }

    /**
     * Creates a {@link ProcessFn} for {@code fn} using the element coder, restriction coder and
     * input windowing strategy of the original transform.
     *
     * @param fn the splittable {@link DoFn} the returned {@link ProcessFn} will drive
     * @return a new, not yet configured {@link ProcessFn}
     */
    public ProcessFn<InputT, OutputT, RestrictionT, PositionT> newProcessFn(
        DoFn<InputT, OutputT> fn) {
      Coder<InputT> elementCoder = original.getElementCoder();
      Coder<RestrictionT> restrictionCoder = original.getRestrictionCoder();
      WindowingStrategy<InputT, ?> windowingStrategy = original.getInputWindowingStrategy();
      return new ProcessFn<>(fn, elementCoder, restrictionCoder, windowingStrategy);
    }

    /** Returns the {@link DoFn} of the original transform. */
    public DoFn<InputT, OutputT> getFn() {
      return original.getFn();
    }

    /** Returns the side inputs of the original transform. */
    public List<PCollectionView<?>> getSideInputs() {
      return original.getSideInputs();
    }

    /** Returns the main output tag of the original transform. */
    public TupleTag<OutputT> getMainOutputTag() {
      return original.getMainOutputTag();
    }

    /** Returns the additional output tags of the original transform. */
    public TupleTagList getAdditionalOutputTags() {
      return original.getAdditionalOutputTags();
    }

    /**
     * Creates the primitive outputs for this transform by delegating to {@link
     * SplittableParDo.ProcessKeyedElements#createPrimitiveOutputFor} with the original transform's
     * fn, output tags, output coders and input windowing strategy.
     *
     * @param input the keyed work items to process
     * @return the tuple of output collections
     */
    @Override
    public PCollectionTuple expand(
        PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> input) {
      PCollectionTuple outputs =
          SplittableParDo.ProcessKeyedElements.createPrimitiveOutputFor(
              input,
              original.getFn(),
              original.getMainOutputTag(),
              original.getAdditionalOutputTags(),
              original.getOutputTagsToCoders(),
              original.getInputWindowingStrategy());
      return outputs;
    }
  }



  /**

   * The heart of splittable {@link DoFn} execution: processes a single (element, restriction) pair

   * by creating a tracker for the restriction and checkpointing/resuming processing later if

   * necessary.

   *

   * <p>Takes {@link KeyedWorkItem} and assumes that the KeyedWorkItem contains a single element (or

   * a single timer set by {@link ProcessFn} itself), in a single window. This is necessary because

   * {@link ProcessFn} sets timers, and timers are namespaced to a single window and it should be

   * the window of the input element.

   *

   * <p>See also: https://issues.apache.org/jira/browse/BEAM-1983

   */

  @VisibleForTesting
  public static class ProcessFn<InputT, OutputT, RestrictionT, PositionT>
      extends DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
    /**
     * The state cell containing a watermark hold for the output of this {@link DoFn}. The hold is
     * acquired during the first {@link DoFn.ProcessElement} call for each element and restriction,
     * and is released when the {@link DoFn.ProcessElement} call returns {@link
     * ProcessContinuation#stop()}.
     *
     * <p>A hold is needed to avoid letting the output watermark immediately progress together with
     * the input watermark when the first {@link DoFn.ProcessElement} call for this element
     * completes.
     */
    private static final StateTag<WatermarkHoldState> watermarkHoldTag =
        StateTags.makeSystemTagInternal(
            StateTags.<GlobalWindow>watermarkStateInternal("hold", TimestampCombiner.LATEST));

    /**
     * The state cell containing a copy of the element. Written during the first {@link
     * DoFn.ProcessElement} call and read during subsequent calls in response to timer firings, when
     * the original element is no longer available.
     */
    private final StateTag<ValueState<WindowedValue<InputT>>> elementTag;

    /**
     * The state cell containing a restriction representing the unprocessed part of work for this
     * element.
     */
    private final StateTag<ValueState<RestrictionT>> restrictionTag;

    private final DoFn<InputT, OutputT> fn;
    private final Coder<InputT> elementCoder;
    private final Coder<RestrictionT> restrictionCoder;
    private final WindowingStrategy<InputT, ?> inputWindowingStrategy;

    // The collaborators below are transient and start out null; they are supplied through the
    // setters (and, for the invoker, by setup()) and are required by processElement.
    private transient @Nullable StateInternalsFactory<byte[]> stateInternalsFactory;
    private transient @Nullable TimerInternalsFactory<byte[]> timerInternalsFactory;
    private transient @Nullable SplittableProcessElementInvoker<
            InputT, OutputT, RestrictionT, PositionT>
        processElementInvoker;

    /** Invoker for the wrapped {@link #fn}; created in {@link #setup()}. */
    private transient @Nullable DoFnInvoker<InputT, OutputT> invoker;

    /**
     * Creates a {@link ProcessFn} driving the given splittable {@link DoFn}.
     *
     * @param fn the splittable {@link DoFn} to invoke
     * @param elementCoder coder for elements, used for the "element" state cell
     * @param restrictionCoder coder for restrictions, used for the "restriction" state cell
     * @param inputWindowingStrategy windowing strategy of the input; its window coder is used for
     *     the "element" state cell and for computing the state namespace
     */
    public ProcessFn(
        DoFn<InputT, OutputT> fn,
        Coder<InputT> elementCoder,
        Coder<RestrictionT> restrictionCoder,
        WindowingStrategy<InputT, ?> inputWindowingStrategy) {
      this.fn = fn;
      this.elementCoder = elementCoder;
      this.restrictionCoder = restrictionCoder;
      this.inputWindowingStrategy = inputWindowingStrategy;
      this.elementTag =
          StateTags.value(
              "element",
              WindowedValue.getFullCoder(
                  elementCoder, inputWindowingStrategy.getWindowFn().windowCoder()));
      this.restrictionTag = StateTags.value("restriction", restrictionCoder);
    }

    /** Sets the factory used to obtain per-key {@link StateInternals} while processing. */
    public void setStateInternalsFactory(StateInternalsFactory<byte[]> stateInternalsFactory) {
      this.stateInternalsFactory = stateInternalsFactory;
    }

    /** Sets the factory used to obtain per-key {@link TimerInternals} while processing. */
    public void setTimerInternalsFactory(TimerInternalsFactory<byte[]> timerInternalsFactory) {
      this.timerInternalsFactory = timerInternalsFactory;
    }

    /** Sets the invoker used to run the wrapped fn's {@link DoFn.ProcessElement} method. */
    public void setProcessElementInvoker(
        SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, PositionT> invoker) {
      this.processElementInvoker = invoker;
    }

    /** Returns the wrapped splittable {@link DoFn}. */
    public DoFn<InputT, OutputT> getFn() {
      return fn;
    }

    /** Returns the element coder passed to the constructor. */
    public Coder<InputT> getElementCoder() {
      return elementCoder;
    }

    /** Returns the restriction coder passed to the constructor. */
    public Coder<RestrictionT> getRestrictionCoder() {
      return restrictionCoder;
    }

    /** Returns the input windowing strategy passed to the constructor. */
    public WindowingStrategy<InputT, ?> getInputWindowingStrategy() {
      return inputWindowingStrategy;
    }

    /** Creates the {@link DoFnInvoker} for the wrapped fn and forwards the setup call to it. */
    @Setup
    public void setup() throws Exception {
      invoker = DoFnInvokers.invokerFor(fn);
      invoker.invokeSetup();
    }

    /** Forwards the teardown call to the wrapped fn. */
    @Teardown
    public void tearDown() throws Exception {
      invoker.invokeTeardown();
    }

    /** Forwards the start-bundle call to the wrapped fn with an adapted context. */
    @StartBundle
    public void startBundle(StartBundleContext c) throws Exception {
      invoker.invokeStartBundle(wrapContextAsStartBundle(c));
    }

    /** Forwards the finish-bundle call to the wrapped fn with an adapted context. */
    @FinishBundle
    public void finishBundle(FinishBundleContext c) throws Exception {
      invoker.invokeFinishBundle(wrapContextAsFinishBundle(c));
    }

    /**
     * Processes an element and restriction pair storing the restriction inside of state.
     *
     * <p>Uses a processing timer to resume execution if processing returns a continuation.
     *
     * <p>Uses a watermark hold to control watermark advancement.
     *
     * @param c context whose element is a {@link KeyedWorkItem} holding either exactly one element
     *     (the seed call) or exactly one timer (a resumption)
     * @throws IllegalStateException if the state/timer factories, the process-element invoker, or
     *     the {@link DoFnInvoker} created by {@link #setup()} have not been initialized
     */
    @ProcessElement
    public void processElement(final ProcessContext c) {
      // Fail fast with a descriptive message instead of an anonymous NullPointerException when the
      // transient collaborators were never supplied.
      StateInternalsFactory<byte[]> stateFactory =
          checkConfigured(stateInternalsFactory, "stateInternalsFactory");
      TimerInternalsFactory<byte[]> timerFactory =
          checkConfigured(timerInternalsFactory, "timerInternalsFactory");
      SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, PositionT> elementInvoker =
          checkConfigured(processElementInvoker, "processElementInvoker");
      DoFnInvoker<InputT, OutputT> fnInvoker = checkConfigured(invoker, "invoker");

      KeyedWorkItem<byte[], KV<InputT, RestrictionT>> workItem = c.element();
      byte[] key = workItem.key();
      StateInternals stateInternals = stateFactory.stateInternalsForKey(key);
      TimerInternals timerInternals = timerFactory.timerInternalsForKey(key);

      // Initialize state (element and restriction) depending on whether this is the seed call.
      // The seed call is the first call for this element, which actually has the element.
      // Subsequent calls are timer firings and the element has to be retrieved from the state.
      TimerInternals.TimerData timer = Iterables.getOnlyElement(workItem.timersIterable(), null);
      boolean isSeedCall = (timer == null);

      // Only populated for the seed call; read once and reused for both the namespace and the
      // element/restriction extraction below.
      @Nullable WindowedValue<KV<InputT, RestrictionT>> seedValue = null;
      StateNamespace stateNamespace;
      if (isSeedCall) {
        seedValue = Iterables.getOnlyElement(workItem.elementsIterable());
        BoundedWindow window = Iterables.getOnlyElement(seedValue.getWindows());
        @SuppressWarnings("unchecked") // Same unchecked narrowing of the window coder as before.
        Coder<BoundedWindow> windowCoder =
            (Coder<BoundedWindow>) inputWindowingStrategy.getWindowFn().windowCoder();
        stateNamespace = StateNamespaces.window(windowCoder, window);
      } else {
        // Timers are namespaced to the window of the element that set them.
        stateNamespace = timer.getNamespace();
      }

      ValueState<WindowedValue<InputT>> elementState =
          stateInternals.state(stateNamespace, elementTag);
      ValueState<RestrictionT> restrictionState =
          stateInternals.state(stateNamespace, restrictionTag);
      WatermarkHoldState holdState = stateInternals.state(stateNamespace, watermarkHoldTag);

      WindowedValue<InputT> element;
      RestrictionT restriction;
      if (isSeedCall) {
        KV<InputT, RestrictionT> elementAndRestriction = seedValue.getValue();
        element = seedValue.withValue(elementAndRestriction.getKey());
        // Persist the element so later timer firings can recover it.
        elementState.write(element);
        restriction = elementAndRestriction.getValue();
      } else {
        // This is not the first ProcessElement call for this element/restriction - rather,
        // this is a timer firing, so we need to fetch the element and restriction from state.
        elementState.readLater();
        restrictionState.readLater();
        element = elementState.read();
        restriction = restrictionState.read();
      }

      final RestrictionTracker<RestrictionT, PositionT> tracker =
          fnInvoker.invokeNewTracker(restriction);
      SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, PositionT>.Result result =
          elementInvoker.invokeProcessElement(fnInvoker, element, tracker);

      // Save state for resuming.
      RestrictionT residualRestriction = result.getResidualRestriction();
      if (residualRestriction == null) {
        // All work for this element/restriction is completed. Clear state and release hold.
        elementState.clear();
        restrictionState.clear();
        holdState.clear();
        return;
      }
      restrictionState.write(residualRestriction);

      // Without an explicit future output watermark, hold at the element's own timestamp.
      @Nullable Instant futureOutputWatermark = result.getFutureOutputWatermark();
      if (futureOutputWatermark == null) {
        futureOutputWatermark = element.getTimestamp();
      }
      Instant wakeupTime =
          timerInternals.currentProcessingTime().plus(result.getContinuation().resumeDelay());
      holdState.add(futureOutputWatermark);
      // Set a timer to continue processing this element.
      timerInternals.setTimer(
          TimerInternals.TimerData.of(stateNamespace, wakeupTime, TimeDomain.PROCESSING_TIME));
    }

    /**
     * Returns {@code value} if it is non-null, otherwise reports which collaborator is missing.
     *
     * @param value the collaborator to check
     * @param fieldName name of the field, used in the error message
     * @throws IllegalStateException if {@code value} is null
     */
    private static <T> T checkConfigured(@Nullable T value, String fieldName) {
      if (value == null) {
        throw new IllegalStateException(
            String.format(
                "%s.%s was not initialized before processElement was called",
                ProcessFn.class.getSimpleName(), fieldName));
      }
      return value;
    }

    /** Adapts this fn's {@link StartBundleContext} to one belonging to the wrapped fn. */
    private DoFn<InputT, OutputT>.StartBundleContext wrapContextAsStartBundle(
        final StartBundleContext baseContext) {
      return fn.new StartBundleContext() {
        @Override
        public PipelineOptions getPipelineOptions() {
          return baseContext.getPipelineOptions();
        }
      };
    }

    /**
     * Adapts this fn's {@link FinishBundleContext} to one belonging to the wrapped fn. Pipeline
     * options are forwarded; both output methods throw {@link UnsupportedOperationException}.
     */
    private DoFn<InputT, OutputT>.FinishBundleContext wrapContextAsFinishBundle(
        final FinishBundleContext baseContext) {
      return fn.new FinishBundleContext() {
        @Override
        public void output(OutputT output, Instant timestamp, BoundedWindow window) {
          throwUnsupportedOutput();
        }

        @Override
        public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
          throwUnsupportedOutput();
        }

        @Override
        public PipelineOptions getPipelineOptions() {
          return baseContext.getPipelineOptions();
        }

        private void throwUnsupportedOutput() {
          throw new UnsupportedOperationException(
              String.format(
                  "Splittable DoFn can only output from @%s",
                  ProcessElement.class.getSimpleName()));
        }
      };
    }
  }

}